import stable_baselines3
import pprint
import numpy as np
import cv2
import gym
import highway_env
from stable_baselines3 import DQN, DDPG
from sb3_contrib import TRPO

from highway_env.vehicle.kinematics import Performance, Logger


def printerr(input: str):
    """Print ``input`` followed by a newline.

    NOTE(review): despite its name this writes to standard output, not
    standard error.
    """
    text = input
    print(text)
    
    
# Used to name saved models in an "xey" format, e.g. 11000 iterations -> 11e3
def float_to_e(f):
    """Format a number as a compact ``<significant digits>e<trailing zeros>`` string.

    The value is first truncated with ``int()``. Its trailing zeros are then
    moved into the exponent, e.g. 11000 -> "11e3", 10100 -> "101e2",
    5 -> "5e0". Used to build model file names from iteration counts.

    Zero is rendered as "0e0" (the previous loop produced the malformed "e1").
    Negative values keep their sign, e.g. -5000 -> "-5e3" (the previous loop
    raised ValueError).
    """
    digits = str(int(f))
    significant = digits.rstrip("0")
    if not significant:
        # Only possible for 0: every digit was a zero.
        return "0e0"
    return f"{significant}e{len(digits) - len(significant)}"


def models(situation: str, env, alg):
    """Build a fresh, untrained model of the requested algorithm on ``env``.

    Args:
        situation: Environment id; used as the prefix of the TensorBoard log
            directory (``"<situation>_TRPO/"`` or ``"<situation>_DQN/"``).
        env: The environment the model will train on.
        alg: ``"TRPO"`` or ``"DQN"`` (case-insensitive).

    Returns:
        An ``sb3_contrib.TRPO`` or ``stable_baselines3.DQN`` instance.

    Raises:
        ValueError: If ``alg`` is neither TRPO nor DQN. Previously this
            surfaced as an UnboundLocalError on ``model``.
    """
    algorithm = alg.upper()

    if algorithm == "TRPO":
        return TRPO("MlpPolicy", env,
                    learning_rate=0.00001,  # earlier value noted in source: 0.001
                    n_steps=1024,
                    batch_size=128,
                    gamma=0.99,
                    cg_max_steps=15,
                    cg_damping=0.1,
                    line_search_shrinking_factor=0.8,
                    line_search_max_iter=10,
                    n_critic_updates=10,
                    gae_lambda=0.95,
                    use_sde=False,
                    sde_sample_freq=-1,
                    normalize_advantage=True,
                    target_kl=0.01,
                    sub_sampling_factor=1,
                    policy_kwargs=None,
                    verbose=1,
                    tensorboard_log=f"{situation}_TRPO/",
                    seed=None,
                    # NOTE(review): hard-coded GPU device; presumably fails on
                    # machines without CUDA -- consider 'auto'.
                    device='cuda',
                    _init_setup_model=True)

    if algorithm == "DQN":
        return DQN('MlpPolicy', env,
                   policy_kwargs=dict(net_arch=[256, 256]),
                   learning_rate=5e-4,
                   buffer_size=15000,
                   learning_starts=200,
                   batch_size=32,
                   gamma=0.8,
                   train_freq=1,
                   gradient_steps=1,
                   target_update_interval=50,
                   verbose=1,
                   tensorboard_log=f"{situation}_DQN/")

    raise ValueError(f"Unsupported algorithm {alg!r}; expected 'TRPO' or 'DQN'")


def learn(situation: str, alg: str, new_model, iterations, load_path):
    """Train a model in rounds, checkpointing and evaluating after each round.

    Args:
        situation: Gym environment id passed to ``gym.make``; also part of the
            checkpoint directory ``models/<situation>_<alg>/``.
        alg: ``"TRPO"`` or ``"DQN"`` (case-insensitive).
        new_model: If truthy, build a fresh model via ``models``. Otherwise
            load one from ``load_path``.
        iterations: Cumulative timestep targets, e.g. ``[1e4, 2e4, 5e4]``.
            Each round trains for the difference to the previous target.
            Presumably non-decreasing -- not checked.
        load_path: Path of a saved model; only used when ``new_model`` is falsy.
            Its last ``/``-separated component prefixes the new checkpoint names.

    Raises:
        ValueError: If ``alg`` is neither TRPO nor DQN.
    """
    algorithm = alg.upper()
    if algorithm not in ("TRPO", "DQN"):
        # Fail before creating any environment; previously this ended in an
        # UnboundLocalError on ``model``.
        raise ValueError(f"Unsupported algorithm {alg!r}; expected 'TRPO' or 'DQN'")

    config = {
        'offroad_terminal': True,
        "screen_width": 1280,
        "screen_height": 560,
        "renderfps": 16,
        'simulation_frequency': 15,
        'policy_frequency': 15,
        'other_vehicles': 1,  # non-ego vehicles
    }
    if algorithm == "TRPO":
        # TRPO is configured with a continuous action space.
        config.update({
            'action': {'type': 'ContinuousAction'},
            'lateral': True,
            'longitudinal': True,
            'vehicles_count': 1,
        })

    # Create the environment exactly once. The TRPO path used to call
    # gym.make a second time and leak the first instance.
    env = gym.make(situation)
    try:
        env.configure(config)
        env.reset()

        if new_model:
            model = models(situation, env, alg)
        else:
            model_cls = TRPO if algorithm == "TRPO" else DQN
            model = model_cls.load(load_path)
            model.set_env(env)

        previous_target = 0
        for round_idx, target in enumerate(iterations):
            # Keep the timestep counter running after the first round. With the
            # default reset, schedules such as DQN's exploration rate and
            # learning_starts would restart every round, and so would the
            # TensorBoard step axis, even though the targets are cumulative.
            model.learn(int(target - previous_target),
                        reset_num_timesteps=(round_idx == 0))
            previous_target = target

            if new_model:
                save_path = "models/" + situation + "_" + alg + f"/{float_to_e(target)}"
            else:
                save_path = ("models/" + situation + "_" + alg
                             + f"/{load_path.split('/')[-1]}+{float_to_e(target)}")
            model.save(save_path)
            performace_test(env, model, save_path, round_idx)
            print(f"\n Finished learning for round {target} of {iterations} \n")
    finally:
        env.close()


def performace_test(env, model, save_path, i, number_of_runs=100, max_steps=800):
    """Evaluate ``model`` on ``env`` and write the aggregated metrics to a file.

    Runs ``number_of_runs`` episodes with deterministic actions. At every step
    the ego vehicle is recorded with ``Logger.file``. After each episode the
    log is added to a ``Performance`` aggregate and then cleared.

    Args:
        env: Environment exposing ``reset``, ``step`` (5-tuple API) and
            ``controlled_vehicles``.
        model: Object with an SB3-style ``predict(obs, deterministic=True)``.
        save_path: Path prefix; results go to ``save_path + ".txt"``.
        i: Round index. ``0`` truncates the results file; later rounds append.
        number_of_runs: Number of evaluation episodes (default 100, the value
            that used to be hard-coded).
        max_steps: Per-episode step cap (default 800, previously hard-coded).
    """
    perfm = Performance()
    lolly = Logger()

    for _ in range(number_of_runs):
        obs, info = env.reset()
        ego_car = env.controlled_vehicles[0]
        terminated = truncated = False
        steps = 0

        # End the episode when it terminates or is truncated. Truncation was
        # previously ignored, so stepping continued past the episode end.
        # Also stop when the ego speed drops to 2 or below (units unconfirmed)
        # or when the step cap is reached.
        while not (terminated or truncated) and ego_car.speed > 2 and steps < max_steps:
            action, _states = model.predict(obs, deterministic=True)
            obs, reward, terminated, truncated, info = env.step(action)
            steps += 1
            lolly.file(ego_car)

        perfm.add_measurement(lolly)
        lolly.clear_log()

    # NOTE(review): if print_performance() prints itself and returns None,
    # this also prints "None" -- confirm against Performance.
    print(perfm.print_performance())
    mode = "w" if i == 0 else "a"
    with open(save_path + ".txt", mode) as my_file:
        my_file.write(f"{perfm.string_rep()}")
        my_file.write("\n\n")
        
        
def optimize_reward(situation: str, alg: str):
    """Grid-search integer reward weightings and log the best one per metric.

    Every weighting ``[i, j, k]`` of positive integers with ``i + j + k == 6``
    is tried. For each one, the env's ``weights_array`` is set to it, a fresh
    model from ``models`` is trained for 50 000 timesteps, and the model is
    evaluated over 100 deterministic episodes. Each trial's metrics are written
    to ``diff_weights.txt``: the first trial truncates the file, later trials
    append. Finally, for each of the 8 entries of ``Performance.array_rep()``,
    the weighting with the best value seen across all trials is appended.

    Args:
        situation: Gym environment id passed to ``gym.make`` and ``models``.
        alg: ``"TRPO"`` or ``"DQN"`` (case-insensitive).

    Raises:
        ValueError: If ``alg`` is neither TRPO nor DQN.
    """
    algorithm = alg.upper()
    if algorithm not in ("TRPO", "DQN"):
        raise ValueError(f"Unsupported algorithm {alg!r}; expected 'TRPO' or 'DQN'")

    # For each metric index of perfm.array_rep(): True if a larger value is
    # better, False if a smaller value is better. These mirror the comparisons
    # of the original implementation. Assumes array_rep() has >= 8 entries.
    higher_is_better = (True, False, False, True, False, True, True, False)

    def run_trial(reward_weights):
        """Train a fresh model with ``reward_weights``; return its Performance."""
        config = {
            'offroad_terminal': True,
            "screen_width": 1280,
            "screen_height": 560,
            "renderfps": 16,
            'simulation_frequency': 15,
            'policy_frequency': 15,
            'other_vehicles': 1,  # non-ego vehicles
            'weights_array': reward_weights,
        }
        if algorithm == "TRPO":
            config.update({
                'action': {'type': 'ContinuousAction'},
                'lateral': True,
                'longitudinal': True,
                'vehicles_count': 1,
            })

        # One env per trial. It was previously created twice for TRPO and
        # never closed.
        env = gym.make(situation)
        try:
            env.configure(config)
            env.reset()
            model = models(situation, env, alg)
            model.learn(int(5e4))

            perfm = Performance()
            lolly = Logger()
            for _ in range(100):
                obs, info = env.reset()
                ego_car = env.controlled_vehicles[0]
                terminated = truncated = False
                steps = 0
                # Stop on termination or truncation, when the ego speed drops
                # to 2 or below, or after 800 steps.
                while not (terminated or truncated) and ego_car.speed > 2 and steps < 800:
                    action, _states = model.predict(obs, deterministic=True)
                    obs, reward, terminated, truncated, info = env.step(action)
                    steps += 1
                    lolly.file(ego_car)
                perfm.add_measurement(lolly)
                lolly.clear_log()
            return perfm
        finally:
            env.close()

    best_weights = [-1] * len(higher_is_better)  # stays -1 until a trial runs
    best_values = None
    first_trial = True

    for i in range(1, 6):
        # The upper bound on j keeps k = 6 - i - j >= 1. The loop variables are
        # no longer rebound inside the body: the old inner ``for i in range(8)``
        # corrupted the weights, e.g. [7, 2, -3].
        for j in range(1, 6 - i):
            reward_weights = [i, j, 6 - i - j]
            perfm = run_trial(reward_weights)
            feat = perfm.array_rep()

            with open("diff_weights.txt", "w" if first_trial else "a") as my_file:
                my_file.write(f"{reward_weights}")
                my_file.write("\n")
                my_file.write(f"{perfm.string_rep()}")
                my_file.write("---------------------------- \n\n")
            first_trial = False

            if best_values is None:
                best_weights = [reward_weights] * len(higher_is_better)
                best_values = [feat[m] for m in range(len(higher_is_better))]
                continue
            # Compare against the best value seen so far. The old code compared
            # only with the immediately preceding trial.
            for m, higher in enumerate(higher_is_better):
                better = feat[m] > best_values[m] if higher else feat[m] < best_values[m]
                if better:
                    best_weights[m] = reward_weights
                    best_values[m] = feat[m]

    # Raw string: "\_" is not a valid escape sequence. The runtime text is
    # unchanged from the previous non-raw literal.
    separator = r"\_/-" * 10
    with open("diff_weights.txt", "a") as my_file:
        my_file.write(separator)
        my_file.write(f"{best_weights}")
        my_file.write(separator)

                
            
            